import os, json
import numpy as np
import pandas as pd
from scipy import stats
from pathlib import Path
from typing import Dict, List
from sklearn.preprocessing import LabelEncoder

## Load and preprocess raw multiclass classification data.
def load_dataset(data: str) -> Dict:
    """Load a raw multiclass classification dataset from ``data/<data>/``.

    Parameters
    ----------
    data : str
        Dataset name; one of ``"optdigits"``, ``"pendigits"``, ``"sat"``
        or ``"letter"``. It is also used as the sub-directory name.

    Returns
    -------
    Dict
        ``n_data`` (number of rows), ``n_class`` (number of distinct labels),
        ``X`` (every column except the last) and ``y`` (last column as int).

    Raises
    ------
    ValueError
        If ``data`` is not one of the supported dataset names.
    """
    data_path = Path(f"data/{data}")
    if data in ("optdigits", "pendigits"):
        # Both digit datasets are comma-separated and split into .tra/.tes
        # files; the two splits are stacked row-wise into a single array.
        data_ = np.r_[
            np.loadtxt(data_path / f"{data}.tra", delimiter=","),
            np.loadtxt(data_path / f"{data}.tes", delimiter=","),
        ]
    elif data == "sat":
        data_ = np.r_[
            np.loadtxt(data_path / f"{data}.trn", delimiter=" "),
            np.loadtxt(data_path / f"{data}.tst", delimiter=" "),
        ]
        # Re-index labels: 7 becomes 5, every other label is shifted down by 1.
        # NOTE(review): presumably label 6 never occurs in this dataset, which
        # would make the result contiguous 0..5 -- confirm against the data.
        data_[:, -1] = np.where(data_[:, -1] == 7, 5, data_[:, -1] - 1)
    elif data == "letter":
        data_ = np.genfromtxt(
            data_path / "letter-recognition.data", delimiter=",", dtype="str"
        )
        # The label is the first (string) column: encode it as integers and
        # move it to the last column so all datasets share one layout.
        le = LabelEncoder()
        data_ = np.c_[data_[:, 1:], le.fit_transform(data_[:, 0])].astype(float)
    else:
        # Fail loudly instead of hitting an UnboundLocalError further down.
        raise ValueError(
            f"Unknown dataset {data!r}; expected one of "
            "'optdigits', 'pendigits', 'sat', 'letter'."
        )

    X = data_[:, :-1]
    y = data_[:, -1].astype(int)
    n_data = data_.shape[0]
    n_class = np.unique(y).shape[0]

    return dict(
        n_data=n_data,
        n_class=n_class,
        X=X,
        y=y
    )

## Load state action data.
def load_state_action_data(dataset: str):
    """Read the state and action parquet files stored for ``dataset``.

    Both files are read from ``action_data/<dataset>/``.

    Returns
    -------
    tuple
        ``(state_data_df, action_data_df, num_features, num_actions)``.
        ``num_features`` is the number of state columns, not counting a ``y``
        column when one is present. ``num_actions`` is one more than the
        largest label found (taken from ``y`` if present, otherwise from the
        ``y_gt`` / ``y_eval_pol`` entries of the action data).
    """
    state_data_df = pd.read_parquet(f'action_data/{dataset}/state_data.parquet')

    # Counts are derived from the frames themselves because DataFrame.attrs
    # does not survive a parquet round trip.
    has_label_column = 'y' in state_data_df.columns
    column_count = state_data_df.shape[1]
    num_features = column_count - 1 if has_label_column else column_count

    action_data_df = pd.read_parquet(f'action_data/{dataset}/action_data.parquet')

    if not has_label_column:
        # No label column: each cell of these columns is reduced with np.max,
        # then the column-wide maximum is taken.
        largest_labels = [
            int(action_data_df[column].apply(np.max).max())
            for column in ('y_gt', 'y_eval_pol')
        ]
        num_actions = max(largest_labels) + 1
    else:
        # NOTE(review): assumes labels are 0-based integers -- confirm.
        num_actions = int(state_data_df['y'].max()) + 1

    return state_data_df, action_data_df, num_features, num_actions

## Read job configs from a JSON Lines (.jsonl) file.
def read_jsonl_file(filepath) -> list:
    """Parse a JSON Lines file into a list of decoded objects.

    Parameters
    ----------
    filepath
        Path of a UTF-8 text file holding one JSON document per line.

    Returns
    -------
    list
        The decoded objects, in file order. Blank lines are skipped silently;
        lines that are not valid JSON are reported on stdout and skipped.
    """
    data = []
    with open(filepath, 'r', encoding='utf-8') as f:
        for line in f:
            stripped = line.strip()
            if not stripped:
                # Blank lines (e.g. a trailing newline at end of file) hold no
                # record, so they should not be reported as decode errors.
                continue
            try:
                data.append(json.loads(stripped))
            except json.JSONDecodeError as e:
                # Best effort: report the bad line and keep reading.
                print(f"Error decoding JSON on line: {stripped}. Error: {e}")
    return data

## Compute Relative RMSEs
def compute_rel_rmse(ope_results: List):
    """Compute the relative RMSE of each estimator over repeated experiments.

    Parameters
    ----------
    ope_results : List
        One mapping per experiment. Each holds the ground-truth value under
        ``'J_gt'`` plus one entry per estimator; keys starting with
        ``'time_'`` are ignored.

    Returns
    -------
    dict
        Estimator name -> ``sqrt(sum((J_gt - estimate)**2) / n_exp) / J_gt``.
        An empty input yields an empty dict.
    """
    squared_error_sums = {}
    num_runs = len(ope_results)
    ground_truth = None
    for run in ope_results:
        ground_truth = run['J_gt']
        for name, estimate in run.items():
            # Only estimator entries contribute; skip ground truth and timings.
            if name == 'J_gt' or name.startswith('time_'):
                continue
            squared_error = (ground_truth - estimate) ** 2
            if name in squared_error_sums:
                squared_error_sums[name] += squared_error
            else:
                squared_error_sums[name] = squared_error

    # NOTE(review): the normaliser is the J_gt of the *last* experiment and the
    # mean always divides by the total number of experiments, even for a key
    # missing from some of them -- confirm both are intended.
    return {
        name: np.sqrt(total / num_runs) / ground_truth
        for name, total in squared_error_sums.items()
    }

## Save Relative RMSEs
def save_rel_rmses(job_name: str, config: Dict, rel_rmses: Dict):
    """Write one relative-RMSE CSV per dataset under ``log/<job_name>/``.

    Parameters
    ----------
    job_name : str
        Name of the sub-directory of ``log`` to write into (created if needed).
    config : Dict
        Must contain ``estimate_pi_b``, ``num_loggers``, ``n_fold``,
        ``exploration_probs`` and ``stratum_ratio``; all five are encoded in
        the output filename.
    rel_rmses : Dict
        Dataset name -> mapping of method name -> relative RMSE.

    Notes
    -----
    Each CSV has the columns ``method`` and ``rel_rmse``. An ``OSError`` while
    writing is reported on stdout (including the underlying error text) and
    does not abort the remaining datasets.
    """
    estimate_pi_b = config['estimate_pi_b']
    num_loggers = config['num_loggers']
    n_fold = config['n_fold']
    exploration_probs = config['exploration_probs']
    stratum_ratio = config['stratum_ratio']

    output_dir = os.path.join("log", job_name)
    os.makedirs(output_dir, exist_ok=True)

    def _format_list_for_filename(values):
        # Flatten to floats, print each with at most 6 decimals (trailing
        # zeros and a dangling "." removed) and join the pieces with "-".
        arr = np.asarray(values, dtype=float).ravel().tolist()

        def fmt(x):
            s = f"{x:.6f}".rstrip("0").rstrip(".")
            return s if s else "0"

        return "-".join(fmt(v) for v in arr)

    for dataset, rel_rmse in rel_rmses.items():
        fname = (
            "rel_rmse__"
            f"dataset={dataset}__"
            f"estimate_pi_b={'true' if estimate_pi_b else 'false'}__"
            f"num_loggers={num_loggers}__"
            f"n_fold={n_fold}__"
            f"exploration_probs={_format_list_for_filename(exploration_probs)}__"
            f"stratum_ratio={_format_list_for_filename(stratum_ratio)}.csv"
        )
        fpath = os.path.join(output_dir, fname)
        df = pd.DataFrame([{'method': k, 'rel_rmse': float(v)} for k, v in rel_rmse.items()])
        try:
            df.to_csv(fpath, index=False)
            print(f"    Rel-RMSE Saved: {fpath}")
        except OSError as e:
            # Keep the best-effort behaviour, but surface the real cause
            # instead of only guessing at it.
            print(
                "    OS Error occured in saving OPE results! "
                f"(Maybe the filename is too long?) Error: {e}"
            )

## Compute mean and 95% CI
def compute_mean_ci(input: np.ndarray):
    """Return the sample mean and a normal-approximation 95% interval.

    Parameters
    ----------
    input : np.ndarray
        Sample values; anything ``np.asarray`` accepts.

    Returns
    -------
    tuple
        ``(mean, (ci_low, ci_high))`` where the interval is the central 95%
        of a normal distribution centred on the mean, with the standard error
        of the mean (``scipy.stats.sem``) as its scale.
    """
    samples = np.asarray(input)
    center = np.mean(samples)
    standard_error = stats.sem(samples)
    lower, upper = stats.norm.interval(0.95, loc=center, scale=standard_error)
    return center, (lower, upper)